import os

import matplotlib.pyplot as plt
import numpy as np
import torch
# Visualize saved attention features and the attention map as PNG images.
# threshold=np.inf disables numpy's array summarization so the full map is printed.
np.set_printoptions(threshold=np.inf)

SRC_DIR = "./src"
IMG_DIR = "./img"
# plt.savefig does not create missing directories; ensure the output dir exists.
os.makedirs(IMG_DIR, exist_ok=True)

attn_map = np.load(f"{SRC_DIR}/attn_weight.npy")
attn_feat = np.load(f"{SRC_DIR}/attn_feats.npy")

# Save one image per entry along axis 0 of attn_feat, each transposed before display.
# NOTE(review): assumes every attn_feat[i] is 2-D (imshow needs 2-D or RGB(A)) — confirm.
for i, feat in enumerate(attn_feat):
    plt.imshow(feat.T)
    plt.savefig(f"{IMG_DIR}/attn_feat_{i}_Dup_CNN_attn.png")
    plt.close()

# The attention map is plotted exactly as stored on disk; no softmax is applied here.
print(attn_map.T)
plt.imshow(attn_map.T)
plt.savefig(f"{IMG_DIR}/attn_map_Dup_CNN_attn.png")
# Release the final figure as well, like the per-feature figures above.
plt.close()